import os
import pickle
from collections import defaultdict

import numpy as np

import transformers

import gym
import wrappers as w

import absl.app
import absl.flags
from flax.training.early_stopping import EarlyStopping
from flaxmodels.flaxmodels.lstm.lstm import LSTMRewardModel
from flaxmodels.flaxmodels.gpt2.trajectory_gpt2 import TransRewardModel

from .sampler import TrajSampler
from .jax_utils import batch_to_jax
import JaxPref.reward_transform as r_tf
from .model import FullyConnectedQFunction
from viskit.logging import logger, setup_logger
from .MR import MR
from .replay_buffer import get_d4rl_dataset, index_batch
from .NMR import NMR
from .PrefTransformer import PrefTransformer
from .utils import Timer, define_flags_with_default, set_random_seed, get_user_flags, prefix_metrics, WandBLogger, save_pickle

from ml_collections import config_flags

# Cap the fraction of GPU memory that XLA/JAX preallocates for this process
# (see JAX's GPU memory allocation docs). It only takes effect if it is set
# before JAX initializes its GPU backend.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.15'

# Command-line flags and their defaults, registered through the project helper
# `define_flags_with_default` (.utils); main() recovers the user-set values
# with `get_user_flags(FLAGS, FLAGS_DEF)`.
#
# The code in this file reads only: env, seed, save_model, batch_size,
# clip_action, activations, activation_final, n_epochs, num_query, query_len,
# balance and transformer. The remaining flags are not read here, presumably
# kept for parity with related training scripts (NOTE(review): confirm before
# removing any of them).
FLAGS_DEF = define_flags_with_default(
    # NOTE(review): load_mt_dataset() takes FLAGS.env.split('_')[1] as the
    # Meta-World task name, so this default (which contains no '_') cannot
    # work there; pass --env explicitly.
    env='halfcheetah-medium-v2',
    model_type='MLP',
    max_traj_length=1000,
    # Passed to set_random_seed() and embedded in the saved .npy file names.
    seed=42,
    data_seed=42,
    # When True, main() saves the final reward model via save_pickle() under
    # ./saved_model.
    save_model=True,
    # Mini-batch size for reward-model training.
    batch_size=64,
    early_stop=False,
    min_delta=1e-3,
    patience=10,

    reward_scale=1.0,
    reward_bias=0.0,
    # Dataset actions are clipped to [-clip_action, clip_action] after loading.
    clip_action=0.999,

    reward_arch='256-256',
    orthogonal_init=False,
    # Activation names forwarded to TransRewardModel.
    activations='relu',
    activation_final='none',
    training=True,

    # Training epochs per query round; epoch 0 performs no parameter update.
    n_epochs=2000,
    eval_period=5,

    data_dir='./human_label',
    # Passed to r_tf.get_queries_from_multi(); query rounds stop once the
    # preference dataset holds more than num_query entries.
    num_query=1000,
    # Passed to r_tf.get_queries_from_multi() as the query length; also used
    # as the placeholder loss value recorded at epoch 0.
    query_len=25,
    skip_flag=0,
    # Forwarded as `balance=` to r_tf.get_queries_from_multi().
    balance=False,
    topk=10,
    window=2,
    use_human_label=False,
    feedback_random=False,
    feedback_uniform=False,
    enable_bootstrap=False,

    comment='',
    
    robosuite=False,
    robosuite_dataset_type="ph",
    robosuite_dataset_path='./data',
    robosuite_max_episode_steps=500,

    reward=MR.get_default_config(),
    # Expanded into transformers.GPT2Config(**FLAGS.transformer) in main().
    transformer=PrefTransformer.get_default_config(),
    lstm=NMR.get_default_config(),
    logging=WandBLogger.get_default_config(),
)

# Registers a `--config` flag that loads an ml_collections config file
# (default 'default.py'); lock_config=False leaves the loaded config mutable.
# FLAGS.config is not read anywhere in this file.
config_flags.DEFINE_config_file(
    'config',
    'default.py',
    'File path to the training hyperparameter configuration.',
    lock_config=False)

def load_mt_dataset(FLAGS, data_root='/mnt/data/'):
    """Build a Meta-World MT1 environment and load its offline dataset.

    The Meta-World task name is the second ``'_'``-separated field of
    ``FLAGS.env``. For an env string of the form ``'<prefix>_<task>'`` it is
    ``<task>``.

    Args:
        FLAGS: Parsed absl flags; only ``FLAGS.env`` is read.
        data_root: Directory holding one sub-directory per task, each with a
            ``data_randgoal_08_50_08_batch.npy`` file. Defaults to the
            previously hard-coded ``'/mnt/data/'``.

    Returns:
        ``(dataset, gym_env)``. ``dataset`` is the object stored in the
        ``.npy`` file, converted with ``ndarray.tolist()``. ``gym_env`` is
        the task environment set to the first MT1 training task and wrapped
        in a 500-step ``TimeLimit``.

    Raises:
        ValueError: If ``FLAGS.env`` contains no ``'_'`` separator.
    """
    print('metaworld task')
    import metaworld
    from gym import wrappers

    env_fields = FLAGS.env.split('_')
    if len(env_fields) < 2:
        raise ValueError(
            f"Expected FLAGS.env of the form '<prefix>_<task>', got {FLAGS.env!r}")
    dataset_name = env_fields[1]

    ml1 = metaworld.MT1(dataset_name, seed=1337)  # Construct the benchmark, sampling tasks.
    raw_env = ml1.train_classes[dataset_name]()  # Create an environment for the task.
    raw_env.set_task(ml1.train_tasks[0])
    # Unfreeze the task's random vector *on the underlying env*. Gym wrappers
    # do not forward attribute assignment, so setting this on the TimeLimit
    # wrapper (as before) never reached the Meta-World env. Presumably this
    # lets the goal be re-sampled on reset; confirm against the installed
    # metaworld version.
    raw_env._freeze_rand_vec = False

    gym_env = wrappers.TimeLimit(raw_env, 500)
    gym_env.train_tasks = ml1.train_tasks

    # allow_pickle=True executes pickled code on load: only point data_root
    # at trusted files.
    dataset_path = os.path.join(data_root, dataset_name, 'data_randgoal_08_50_08_batch.npy')
    dataset = np.load(dataset_path, allow_pickle=True).tolist()
    print('dataset has loaded...')

    return dataset, gym_env

def _train_reward_model(reward_model, pref_dataset, n_epochs, batch_size, num_batches, query_len):
    """Train ``reward_model`` on ``pref_dataset`` for epochs ``0..n_epochs``.

    Epoch 0 performs no update; it only records a placeholder loss of
    ``query_len``. Each later epoch shuffles the preference data and runs
    ``num_batches`` mini-batch updates of at most ``batch_size`` items.

    Returns:
        The index of the last epoch run, or ``None`` if ``n_epochs < 0``.
    """
    train_loss = "reward/trans_loss"
    data_size = pref_dataset["observations"].shape[0]
    last_epoch = None
    for epoch in range(n_epochs + 1):
        if epoch % 100 == 0:
            print('epoch: ', epoch, n_epochs)
        # NOTE(review): metrics are collected but never logged in this file.
        metrics = defaultdict(list)
        metrics['epoch'] = epoch
        if epoch:
            # Time the whole epoch, not just its last mini-batch.
            with Timer() as train_timer:
                shuffled_idx = np.random.permutation(data_size)
                for i in range(num_batches):
                    # Slicing clamps the final (possibly partial) batch at data_size.
                    batch_idx = shuffled_idx[i * batch_size:(i + 1) * batch_size]
                    batch = batch_to_jax(index_batch(pref_dataset, batch_idx))
                    for key, val in prefix_metrics(reward_model.train(batch), 'reward').items():
                        metrics[key].append(val)
            metrics['train_time'] = train_timer()
        else:
            metrics[train_loss] = [float(query_len)]
        last_epoch = epoch
    return last_epoch


def main(_):
    """Alternate preference-query collection and reward-model training.

    Each round extends the preference dataset via
    ``r_tf.get_queries_from_multi``, passing the previous round's reward model.
    It saves the two trajectory-return arrays to ``<env>_traj{1,2}_<seed>.npy``,
    then trains a freshly built PrefTransformer. Rounds stop once the
    preference dataset holds more than ``FLAGS.num_query`` entries, or after
    10000 rounds. If ``FLAGS.save_model`` is set, the final model is pickled
    under ``./saved_model``.
    """
    FLAGS = absl.flags.FLAGS
    variant = get_user_flags(FLAGS, FLAGS_DEF)
    model_save_dir = './saved_model'
    set_random_seed(FLAGS.seed)
    dataset, gym_env = load_mt_dataset(FLAGS)
    label_type = 0
    dataset['actions'] = np.clip(dataset['actions'], -FLAGS.clip_action, FLAGS.clip_action)

    observation_dim = gym_env.observation_space.shape[0]
    action_dim = gym_env.action_space.shape[0]

    reward_model = None
    pref_dataset = None
    relabeled_dataset = None
    epoch = None
    for query_index in range(10000):
        pref_dataset, relabeled_dataset = r_tf.get_queries_from_multi(
            gym_env, dataset, relabeled_dataset, FLAGS.num_query, FLAGS.query_len, reward_model, pref_dataset, query_index,
            label_type=label_type, balance=FLAGS.balance)

        data_size = pref_dataset["observations"].shape[0]
        np.save(f'{FLAGS.env}_traj1_{FLAGS.seed}.npy', pref_dataset['traj return_1'])
        np.save(f'{FLAGS.env}_traj2_{FLAGS.seed}.npy', pref_dataset['traj return_2'])
        print('current pref data size: ', data_size)

        # Ceiling division. The old `int(n / b) + 1` added an empty trailing
        # mini-batch whenever data_size % batch_size == 0.
        interval = -(-data_size // FLAGS.batch_size)
        total_epochs = FLAGS.n_epochs
        config = transformers.GPT2Config(**FLAGS.transformer)
        config.warmup_steps = int(total_epochs * 0.1 * interval)
        config.total_steps = total_epochs * interval

        trans = TransRewardModel(
            config=config, observation_dim=observation_dim, action_dim=action_dim,
            activation=FLAGS.activations, activation_final=FLAGS.activation_final)
        reward_model = PrefTransformer(config, trans)

        epoch = _train_reward_model(
            reward_model, pref_dataset, FLAGS.n_epochs, FLAGS.batch_size, interval, FLAGS.query_len)

        if data_size >= FLAGS.num_query + 1:
            break

    if FLAGS.save_model:
        save_data = {'reward_model': reward_model, 'variant': variant, 'epoch': epoch}
        save_pickle(save_data, f'model_{FLAGS.env}_iter_{query_index}.pkl', model_save_dir)
        print('save done...   epoch: ', epoch)
    
if __name__ == '__main__':
    # absl.app.run parses the command-line flags defined above, then calls main(argv).
    absl.app.run(main)
